import tensorflow as tf
import model
import Inputs

FLAGS = tf.app.flags.FLAGS

# NOTE(review): checkArgs() reads FLAGS.log_dir, which is not defined below.
# Presumably it is registered by `model` or `Inputs` (imported above) -- confirm,
# because reading an undefined flag raises AttributeError. It is deliberately
# not defined here to avoid a duplicate-flag error if another module owns it.

# Run mode: main() tests when --testing is non-empty, else finetunes when
# --finetune is non-empty, else trains from scratch.
tf.app.flags.DEFINE_string('testing', '', """ checkpoint file """)
tf.app.flags.DEFINE_string('finetune', '', """ finetune checkpoint file """)

# Optimisation and input geometry. Defaults are real numbers rather than
# strings; the flag parser produced the same values from the old string forms.
tf.app.flags.DEFINE_integer('batch_size', 5, """ batch_size """)
tf.app.flags.DEFINE_float('learning_rate', 1e-3, """ initial lr """)
tf.app.flags.DEFINE_integer('max_steps', 20000, """ max_steps """)
tf.app.flags.DEFINE_integer('image_h', 360, """ image height """)
tf.app.flags.DEFINE_integer('image_w', 480, """ image width """)
tf.app.flags.DEFINE_integer('image_c', 3, """ image channel (RGB) """)
tf.app.flags.DEFINE_integer('num_class', 11, """ total class number """)
tf.app.flags.DEFINE_boolean('save_image', True, """ whether to save predicted image """)

# Raw CamVid images and their annotation (label) images, one directory per split.
tf.app.flags.DEFINE_string('train_img_dir', "CamVid/train", """ path to CamVid train image """)
tf.app.flags.DEFINE_string('test_img_dir', "CamVid/test", """ path to CamVid test image """)
tf.app.flags.DEFINE_string('val_img_dir', "CamVid/val", """ path to CamVid val image """)
tf.app.flags.DEFINE_string('train_label_dir', "CamVid/trainannot", """ path to CamVid train label image """)
tf.app.flags.DEFINE_string('test_label_dir', "CamVid/testannot", """ path to CamVid test label image """)
tf.app.flags.DEFINE_string('val_label_dir', "CamVid/valannot", """ path to CamVid val label image """)

# TFRecord locations handed to Inputs.initial() by main().
tf.app.flags.DEFINE_string('tfrecord_path', "CamVid/tfrecord", """ path to CamVid tfrecord dataset """)
tf.app.flags.DEFINE_string('train_tfrecord', "CamVid/tfrecord/trainrecord", """ path to CamVid train tfrecord """)
tf.app.flags.DEFINE_string('test_tfrecord', "CamVid/tfrecord/testrecord", """ path to CamVid test tfrecord """)
tf.app.flags.DEFINE_string('val_tfrecord', "CamVid/tfrecord/valrecord", """ path to CamVid val tfrecord """)
tf.app.flags.DEFINE_boolean('data_shuffle', True, """ dataset shuffle option """)


def checkArgs(flags=None):
    """Print a human-readable summary of the run configuration.

    The mode is chosen the same way ``main`` chooses it: a non-empty
    ``testing`` checkpoint selects testing, otherwise a non-empty
    ``finetune`` checkpoint selects finetuning, otherwise plain training.

    Args:
        flags: object exposing the flag values as attributes. Defaults to the
            module-level ``FLAGS``; passing one explicitly allows printing a
            summary without going through the command-line parser.
    """
    if flags is None:
        flags = FLAGS

    # Truthiness (not `!= ''`) so the reported mode always matches main().
    if flags.testing:
        print('The model is set to Testing')
        print("check point file: %s" % flags.testing)
        print("CamVid testing dir: %s" % flags.test_img_dir)
    elif flags.finetune:
        print('The model is set to Finetune from ckpt')
        print("check point file: %s" % flags.finetune)
        print("CamVid Image dir: %s" % flags.train_img_dir)
        print("CamVid Val dir: %s" % flags.val_img_dir)
    else:
        print('The model is set to Training')
        print("Max training Iteration: %d" % flags.max_steps)
        print("Initial lr: %f" % flags.learning_rate)
        print("CamVid Image dir: %s" % flags.train_img_dir)
        print("CamVid Val dir: %s" % flags.val_img_dir)

    print("Batch Size: %d" % flags.batch_size)
    # `log_dir` is not among the flags defined in this file. Reading it directly
    # raises AttributeError (after the summary above was already printed)
    # unless some other module registers it, so report it as unset instead.
    log_dir = getattr(flags, 'log_dir', None)
    print("Log dir: %s" % (log_dir if log_dir is not None else '(not set)'))


def main(args):
    """Entry point called by ``tf.app.run`` after the flags are parsed.

    Collects an ``[image dir, label dir, tfrecord path]`` list for each of the
    train / val / test splits, prints the configuration, passes the paths to
    ``Inputs.initial`` and then runs testing, finetuning or training.

    Args:
        args: leftover command-line arguments supplied by ``tf.app.run``;
            not used.
    """
    def split_paths(split):
        # Element order matches what Inputs.initial receives for every split:
        # image directory, label directory, tfrecord path.
        return [getattr(FLAGS, split + suffix)
                for suffix in ('_img_dir', '_label_dir', '_tfrecord')]

    train_paths, val_paths, test_paths = (
        split_paths(name) for name in ('train', 'val', 'test'))

    checkArgs()
    Inputs.initial(FLAGS.tfrecord_path, train_paths, val_paths, test_paths)

    if FLAGS.testing:
        model.test(FLAGS)
    else:
        # A non-empty --finetune checkpoint turns training into finetuning.
        model.training(FLAGS, is_finetune=bool(FLAGS.finetune))


if __name__ == '__main__':
    # Parse the flags defined above, then hand control to main(argv).
    tf.app.run(main=main)
